import sys  # 需要使用命令行参数
import codecs  # 需要打开utf-8编码的文件
import sqlite3  # 连接SQLite3
from codecs import StreamReaderWriter
from typing import Union, TextIO
from matplotlib import pyplot
from sqlite3.dbapi2 import Cursor, Connection
from mglearn import discrete_scatter
import numpy  # 使用numpy.array来装数据
# from sklearn.cluster import KMeans  # 使用KMeans处理numpy.array所装的数据
# from sklearn.metrics import calinski_harabasz_score  # , silhouette_score
# from sklearn.decomposition import PCA
# import math
import ctypes
import platform
from mpl_toolkits.mplot3d import Axes3D  # Used for 3D painting


def main():
    """Command-line entry point.

    Usage: ``<script> [train|predict] [database [script [checkpoint]]]``

    * ``train`` / ``predict`` selects the mode of the `Database`;
    * with exactly a database argument, an interactive ``>>>`` prompt is started
      (end it with EOF or a ``quit``/``exit``/``done`` command);
    * with a script argument, the checkpoint (if given) is opened and the script is run.

    When no arguments at all are given, the defaults
    ``train test.db allinone.script company.checkpoint`` are used (these used to be
    hard-coded and overrode any real command-line arguments).
    """
    argv = sys.argv  # command-line argument list
    if len(argv) == 1:
        # No arguments: fall back to the previously hard-coded defaults.
        argv = [argv[0], 'train', 'test.db', 'allinone.script', 'company.checkpoint']
    argc = len(argv)  # number of command-line arguments
    db = Database()
    if argc >= 2:
        if argv[1] == 'train':
            db.train()
        elif argv[1] == 'predict':
            db.predict()
    if argc >= 3:
        db.open(argv[2])
    if argc == 3:
        print('Interactive Mode')
        while True:
            try:
                line = input('>>> ')
            except EOFError:
                sys.exit(0)
            try:
                db.execute(line)
            except Exception as err:  # top-level REPL boundary: report and keep the prompt alive
                print("[ERROR] " + type(err).__name__ + ": " + str(err))
    if argc >= 5:
        # The checkpoint must be open before the script runs (`cluster` uses it).
        db.metastore(argv[4])
    if argc >= 4:
        db.executeScript(argv[3])


class Database:
    """Script-driven toolkit for cleaning, clustering, rating and counting SQLite tables.

    Commands are given one per line as ``command: arg1, arg2, ...`` (see `execute`),
    either from a script file (`executeScript`) or interactively.  The object runs in one
    of two modes:

    * ``'train'`` (default): commands work on the named tables, and `cluster` writes the
      parameters it computes to the checkpoint file if one is open (see `metastore`);
    * ``'predict'``: `copy`, `clean`, `cluster`, `rate` and `visualize` work on
      ``<table>_predict`` instead, `count` does nothing, and `cluster` reads its
      parameters back from the checkpoint.
    """

    checkpoint: Union[TextIO, StreamReaderWriter]
    db: Connection = None
    cursor: Cursor = None

    def __init__(self, filenameDB=None, filenameScript=None, filenameCheckpoint=None):
        """Create the toolkit and optionally open a database and checkpoint and run a script.

        :param filenameDB: SQLite database file to open (see `open`).
        :param filenameScript: command script to run (see `executeScript`); it runs last,
            after the native library, the database and the checkpoint are ready.
        :param filenameCheckpoint: checkpoint file to open (see `metastore`).
        :raises OSError: if the platform's native ``dbkit`` library cannot be loaded.
        """
        self.id = '编号'  # ID column name used when creating tables and reading/writing row IDs
        self.ref = None  # reference table: column 0 = field name, column 1 = table name (see `insert`)
        self.eval = None  # table that receives clustering evaluation scores, or None
        self.checkpoint = None
        self.mode = 'train'
        # Load the native library first: `cluster`, possibly run by the script below, needs it.
        self.lib = self._loadNativeLibrary()
        if filenameDB is not None:
            self.open(filenameDB)
        if filenameCheckpoint is not None:
            # Open the checkpoint before the script runs so `cluster` can read/write it.
            self.metastore(filenameCheckpoint)
        if filenameScript is not None:
            self.executeScript(filenameScript)

    @staticmethod
    def _loadNativeLibrary():
        """Load the ``dbkit`` native library matching this OS and pointer width.

        Returns ``None`` on platforms other than Windows and Linux; commands that need the
        library (`cluster` in train mode) then fail.
        """
        system = platform.system()
        is32Bit = '32' in platform.architecture()[0]
        if system == "Windows":
            return ctypes.cdll.LoadLibrary('dbkit_x86.dll' if is32Bit else 'dbkit_x64.dll')
        if system == "Linux":
            return ctypes.cdll.LoadLibrary('./dbkit_x86.so' if is32Bit else './dbkit_x64.so')
        return None

    def __del__(self):
        """Show any pending matplotlib figures, then release the database and checkpoint."""
        try:
            pyplot.show()
        finally:
            # Release resources even if showing the figures fails.
            self.close()

    # Connection to database
    def open(self, filenameDB):
        """Open (or create) SQLite database *filenameDB*, replacing any open connection.

        An open checkpoint file is left untouched, so a script may switch databases
        without losing its checkpoint.
        """
        self._closeDatabase()
        self.db = sqlite3.connect(filenameDB, check_same_thread=False)
        # Trade durability for speed: no fsync and no rollback journal.
        self.db.execute("PRAGMA synchronous=OFF")
        self.db.execute("PRAGMA journal_mode=OFF")
        self.cursor = self.db.cursor()

    def _closeDatabase(self):
        """Close the cursor and the connection, if open (uncommitted changes are discarded)."""
        if self.cursor is not None:
            self.cursor.close()
            self.cursor = None
        if self.db is not None:
            self.db.close()
            self.db = None

    def close(self):
        """Close the database connection and the checkpoint file, if open."""
        self._closeDatabase()
        if self.checkpoint is not None:
            self.checkpoint.close()
            self.checkpoint = None

    def train(self):
        """Switch to train mode."""
        self.mode = 'train'

    def predict(self):
        """Switch to predict mode (commands use ``<table>_predict`` tables)."""
        self.mode = 'predict'

    # Execute command(s)
    def execute(self, line):
        """Parse one ``command: arg1, arg2, ...`` line and run the matching command.

        Command names are matched exactly after lower-casing (see `interpretCommand`).
        Empty lines and lines starting with '#' are ignored; unknown commands print an
        error.  A wrong number of arguments fails an assertion.
        """
        command, arguments = self.interpretCommand(line.strip())
        if not command or command[0] == '#':
            return
        if command == 'key':
            assert len(arguments) == 1
            self.key(*arguments)
        elif command == 'copy':
            assert len(arguments) >= 2
            self.copy(*arguments[0:2], arguments[2:])
        elif command == 'cluster':
            assert 2 <= len(arguments) <= 3
            self.cluster(*arguments)
        elif command == 'rate':
            assert 2 <= len(arguments) <= 4
            self.rate(*arguments)
        elif command == 'count':
            assert len(arguments) == 2
            self.count(*arguments)
        elif command == 'drop':
            assert len(arguments) == 1
            self.drop(*arguments)
        elif command == 'sql':
            assert len(arguments) == 1
            self.sql(*arguments)
        elif command == 'reference':
            assert len(arguments) == 1
            self.reference(*arguments)
        elif command == 'evaluate':
            assert 0 <= len(arguments) <= 1
            self.evaluate(*arguments)
        elif command == 'metastore':
            assert len(arguments) == 1
            self.metastore(*arguments)
        elif command == 'insert':
            assert len(arguments) >= 3
            self.insert(*arguments)
        elif command == 'clean':
            assert len(arguments) == 1
            self.clean(*arguments)
        elif command == 'visualize':
            assert 1 <= len(arguments) <= 2
            self.visualize(*arguments)
        elif command == 'open':
            assert len(arguments) == 1
            self.open(*arguments)
        elif command == 'close':
            self.close()
        elif command == 'executescript':
            assert len(arguments) == 1
            self.executeScript(*arguments)
        elif command in ('quit', 'exit', 'done'):
            sys.exit(0)
        else:
            print("[ERROR] " + command + " is not a valid command!")

    def executeScript(self, filenameScript):
        """Run every command line of the UTF-8 script *filenameScript*, with progress output.

        Blank lines and lines starting with '#' are skipped.
        """
        print("_____________________________")
        print("|                           /")
        print("| Let's begin.             /")
        print("| ________________________/")
        print("| |")
        with codecs.open(filenameScript, 'r', 'utf-8') as fileScript:
            lines = [line.strip() for line in fileScript.readlines()]
        for line in lines:
            if len(line) == 0 or line[0] == '#':
                continue
            print("| |  Start processing: \"" + line + "\"")
            self.execute(line)
            print("| |  Finished!")
        print("| |________________________")
        print("|                          \\")
        print("| Well done!                \\")
        print("|____________________________\\")

    # Commands
    def key(self, id=None):
        """Set the ID column name used for new tables and row lookups."""
        self.id = id

    def reference(self, ref=None):
        """Set the reference table used by `insert` to find target tables."""
        self.ref = ref

    def evaluate(self, tableEval=None):
        """Set the table that receives clustering evaluation scores (None disables it)."""
        self.eval = tableEval

    def metastore(self, filenameCheckpoint=None):
        """Open the UTF-8 checkpoint file, closing any previously opened one.

        Train mode opens it for writing (truncating it), predict mode for reading, and
        any other mode for reading and writing.
        """
        if self.checkpoint is not None:
            self.checkpoint.close()
            self.checkpoint = None
        if self.mode == 'train':
            fileMode = 'w'
        elif self.mode == 'predict':
            fileMode = 'r'
        else:
            fileMode = 'r+'  # read/write; 'rw' is not a valid file mode
        self.checkpoint = codecs.open(filenameCheckpoint, fileMode, 'utf-8')

    def copy(self, tableDst, tableSrc, fields=None):
        """Copy the ID column and *fields* (default: all non-ID columns) of *tableSrc* into *tableDst*.

        Requested fields missing from *tableSrc* are ignored; nothing happens if none
        remain or if *tableSrc* does not exist.  Column types are copied too.
        """
        if self.mode == 'predict':
            tableSrc += '_predict'
            tableDst += '_predict'
        if tableSrc not in self.getTables():
            return
        allFields = self.getFields(tableSrc)
        allTypenames = self.getTypenames(tableSrc)
        if not fields:
            fields = allFields[1:]
            typenames = allTypenames[1:]
        else:
            typeOf = dict(zip(allFields, allTypenames))
            fields = [field for field in fields if field in typeOf]
            if not fields:
                return
            typenames = [typeOf[field] for field in fields]
        self.createTable(tableDst)
        self.addRows(tableDst, self.tupleList(self.getRows(tableSrc)))
        for field, typename in zip(fields, typenames):
            self.addField(tableDst, field, typename)
            self.copyCells(tableDst, tableSrc, field)
        # Commit once at the end so the new rows are saved even when there are no fields.
        self.db.commit()

    def clean(self, table):
        """Delete every row of *table* that has NULL in any column."""
        if self.mode == 'predict':
            table += '_predict'
        for field in self.getFields(table):
            self.deleteNull(table, field)
        self.db.commit()

    def _skipCheckpoint(self, nLines=3):
        """In predict mode, skip the checkpoint lines (L, D, centers) `cluster` would have read."""
        if self.mode == 'predict':
            for _ in range(nLines):
                self.checkpoint.readline()

    @staticmethod
    def _splitTags(tags):
        """Split a tag list written as ``a/b/c`` or ``a、b、c`` into its parts."""
        return tags.split('/') if '/' in tags else tags.split('、')

    def cluster(self, tableDst, tableSrc, tableWeight=None):
        """Label each row of *tableSrc* with one of the tags named in the table name.

        ``interpretName(tableSrc)`` gives the target column name (part 0) and the tag
        list (part 2, split on '/' or '、').  The tags are assumed to be listed from
        small to large (e.g. "低/中/高").  Rows are clustered into ``len(tags)`` clusters
        by `assortStdNormCkmeans_1d_dp`, and clusters are mapped to tags by ascending
        center.  The result goes to column *target* of *tableDst*.

        :param tableWeight: optional table whose rows are ``(field, weight)``; fields not
            listed there get weight 1.
        """
        if self.mode == 'predict':
            tableSrc += '_predict'
            tableDst += '_predict'
        if tableSrc not in self.getTables():
            self._skipCheckpoint()
            return
        nameParts = self.interpretName(tableSrc)
        target = nameParts[0]
        tags = self._splitTags(nameParts[2])
        data = self.getCells(tableSrc)
        nRow = len(data)  # number of records
        if nRow == 0:
            self._skipCheckpoint()
            return
        nCol = len(data[0])  # number of columns, including the ID column
        if nCol == 0:
            return
        IDs, data = self.splitVertical(data)
        nTags = len(tags)  # the number of tags decides the number of clusters
        if tableWeight is None:
            w = [1] * nCol  # one more weight than data columns; the extra one is never used
        else:
            fields = self.getFields(tableSrc)[1:]
            weightOf = {row[0]: row[1] for row in self.getCells(tableWeight)}
            w = [weightOf.get(field, float(1)) for field in fields]
        # Core clustering.
        X = numpy.array(data)
        iTags, centers, evaluation = self.assortStdNormCkmeans_1d_dp(X, nTags, w)
        # Cluster numbers are arbitrary, so rank clusters by center (ties by number) and
        # give the i-th smallest cluster the i-th configured tag.
        order = sorted(range(nTags), key=lambda i: (centers[i], i))
        tagOf = [''] * nTags
        for rank, iCluster in enumerate(order):
            tagOf[iCluster] = tags[rank]
        # Write the results to the database.
        self.createTable(tableDst)
        self.addField(tableDst, target)
        self.addRows(tableDst, IDs)
        self.fillCells(tableDst, IDs, target, [(tagOf[iTag],) for iTag in iTags])
        if self.eval is not None and evaluation:
            # One row per clustering target, one DOUBLE column per score.
            self.createTable(self.eval, '聚类目标')
            self.addRow(self.eval, target)
            for field in evaluation:
                self.addField(self.eval, field, 'DOUBLE')
                self.fillCell(self.eval, target, field, evaluation[field])
        self.db.commit()  # without commit the changes above are not persisted

    def rate(self, tableDst, tableSrc, tableCriteria=None, target=None):
        """Score each row of *tableSrc* by summing points from a criteria table.

        Rows of *tableCriteria* give ``(field, value, points)`` in their first three
        columns; every non-NULL cell of *tableSrc* whose (column, value) pair appears
        there adds its points.  Scores are written as DOUBLE column *target* of *tableDst*.

        :param tableCriteria: defaults to ``name[0] + "（标准：" + name[2] + "）"`` where
            ``name = interpretName(tableSrc)``.
        :param target: defaults to ``interpretName(tableSrc)[2]``.
        """
        keywords = self.interpretName(tableSrc)
        if target is None:
            target = keywords[2]
        if tableCriteria is None:
            tableCriteria = keywords[0] + "（标准：" + keywords[2] + "）"
        if self.mode == 'predict':
            tableSrc += '_predict'
            tableDst += '_predict'
        if tableSrc not in self.getTables():
            return
        rows = self.getCells(tableSrc)
        if not rows:
            return
        IDs, data = self.splitVertical(rows)
        if len(data[0]) == 0:
            return  # no columns besides the ID
        rowOf = {ID[0]: i for i, ID in enumerate(IDs)}
        rateFields = self.getFields(tableSrc)
        score = {}  # score[field][value] -> points
        for row in self.getCells(tableCriteria):
            score.setdefault(row[0], {})[row[1]] = row[2]
        ratings = [0] * len(IDs)
        for row in rows:
            iRow = rowOf[row[0]]
            for field, cell in zip(rateFields[1:], row[1:]):
                if cell is not None and field in score and cell in score[field]:
                    ratings[iRow] += score[field][cell]
        self.createTable(tableDst)
        self.addField(tableDst, target, 'DOUBLE')
        self.addRows(tableDst, IDs)
        self.fillCells(tableDst, IDs, target, self.tupleList(ratings))
        self.db.commit()

    def count(self, tableDst, tableSrc):
        """Count the non-NULL non-ID cells of each row of *tableSrc* (train mode only).

        Counts go to INTEGER column ``interpretName(tableSrc)[0]`` of *tableDst*.  Nothing
        happens if *tableSrc* is missing, empty, or has only the ID column.
        """
        if self.mode == 'predict':
            return
        if tableSrc not in self.getTables():
            return
        target = self.interpretName(tableSrc)[0]
        IDs, data = self.splitVertical(self.getCells(tableSrc))
        if len(data) == 0 or len(data[0]) == 0:
            return
        cnt = [sum(cell is not None for cell in row) for row in data]
        self.createTable(tableDst)
        self.addField(tableDst, target, 'INTEGER')
        self.addRows(tableDst, IDs)
        self.fillCells(tableDst, IDs, target, self.tupleList(cnt))
        self.db.commit()

    def visualize(self, table, savepath=None):
        """Scatter-plot the feature columns of *table*, coloured by its tag column.

        The last column is the tag and the columns between the ID and the tag are
        numeric features; the tag list comes from ``interpretName(table)[2]``.  One
        feature gives a 1-D scatter (y position by tag), three give a 3-D plot, and any
        other count gives one 2-D plot per pair of features.  With *savepath*, each
        figure is saved as ``savepath + <column names> + '.png'``.
        """
        if self.mode == 'predict':
            table += '_predict'
        if table not in self.getTables():
            return
        pyplot.rcParams['font.sans-serif'] = ['Microsoft YaHei']  # a font with CJK glyphs
        tags = self._splitTags(self.interpretName(table)[2])
        data = self.getCells(table)
        nRow = len(data)  # number of records
        if nRow == 0:
            return  # nothing to plot
        tags.reverse()  # tags are configured small-to-large; list them large-to-small
        nTag = len(tags)
        rankTag = {tag: i for i, tag in enumerate(tags)}
        X = numpy.array([row[1:] for row in data])  # feature values followed by the tag
        X_num = len(X[0]) - 1  # number of feature dimensions; the last column is the tag
        # Index of each row's tag in `tags` ('' if the tag is not listed).
        iTags = [next((j for j, tag in enumerate(tags) if row[X_num] == tag), '') for row in X]
        columns = self.getFields(table)
        c = ['red', 'orange', 'yellow', 'green', 'cyan', 'blue', 'purple', 'pink']
        X_test = numpy.array([nTag - rankTag[X[i][-1]] for i in range(nRow)])
        if X_num == 1:
            pyplot.figure()
            discrete_scatter(X[:, 0].astype(float), X_test, iTags, markers=['.'], c=c)
            pyplot.xlabel(columns[1])
            pyplot.legend(tags, loc='best')
            pyplot.yticks([])
            if savepath is not None:
                savename = savepath + columns[1] + '.png'  # named after the feature
                pyplot.savefig(savename)
        elif X_num == 3:
            mX = {}  # tag -> list of 3-D points
            for row in X:
                mX.setdefault(row[3], []).append(row[0:3].astype(float))
            for tag in mX:
                mX[tag] = numpy.array(mX[tag])
            pyplot.figure()
            ax = pyplot.axes(projection='3d')
            for tag, points in mX.items():
                ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=c[rankTag[tag]], label=tag)
            ax.legend(loc='best')
            ax.set_xlabel(columns[1])
            ax.set_ylabel(columns[2])
            ax.set_zlabel(columns[3])
            if savepath is not None:
                savename = savepath + '+'.join(columns[1:4]) + '.png'  # named after the three features
                pyplot.savefig(savename)
        else:
            for i in range(0, X_num - 1):
                for j in range(i + 1, X_num):
                    pyplot.figure()
                    discrete_scatter(X[:, i].astype(float), X[:, j].astype(float), iTags, markers=['.'], c=c)
                    pyplot.legend(tags, loc='best')
                    pyplot.xlabel(columns[i + 1])
                    pyplot.ylabel(columns[j + 1])
                    if savepath is not None:
                        savename = savepath + columns[i + 1] + '+' + columns[j + 1] + '.png'  # named after the two features
                        pyplot.savefig(savename)

    def drop(self, tableOrView):
        """Drop the table or view named *tableOrView*, if it exists."""
        if tableOrView in self.getTables():
            self.dropTable(tableOrView)
        elif tableOrView in self.getViews():
            self.dropView(tableOrView)
        else:
            return
        self.db.commit()

    def sql(self, scripts: str):
        """Run ';'-separated SQL statements, printing affected-row counts and result rows.

        NOTE: the split on ';' is naive, so a ';' inside a string literal breaks a statement.
        """
        for script in scripts.split(';'):
            script = script.strip()
            if not script:
                continue
            self.cursor.execute(script)
            nAffected = self.cursor.rowcount
            if nAffected != -1:
                print(str(nAffected) + " row(s) affected.")
            for row in self.cursor.fetchall():
                print('\t| '.join([str(cell) for cell in row]))
        self.db.commit()

    def insert(self, ID, field, value, *tablesDst):
        """Set *field* = *value* for row *ID* (creating the row if needed) in target tables.

        Targets are *tablesDst* if given; otherwise the tables listed for *field* in the
        reference table (`self.ref`); otherwise every table having a column *field*.
        """
        if len(tablesDst) > 0:
            tables = tablesDst
        elif self.ref is not None:
            refFields = self.getFields(self.ref)
            rows = self.getCells(self.ref, refFields[1], "[" + refFields[0] + "]=?", (field,))
            tables = [row[0] for row in rows]
        else:
            tables = [table for table in self.getTables() if field in self.getFields(table)]
        for table in tables:
            self.addRow(table, ID)
            self.fillCell(table, ID, field, value)
        self.db.commit()

    # Auxiliary function
    @staticmethod
    def interpretCommand(line):
        """Split ``command: arg1, arg2, ...`` into ``(command, [arg1, arg2, ...])``.

        The command is stripped and lower-cased; arguments are comma-separated and
        stripped, so whitespace in scripts never matters.  For ``sql`` the whole
        remainder is a single argument because SQL may contain commas.
        """
        head, sep, tail = line.strip().partition(':')
        command = head.strip().lower()
        if not sep:
            arguments = []
        elif command == 'sql':
            arguments = [tail.strip()]
        else:
            arguments = [argument.strip() for argument in tail.strip().split(',')]
        return command, arguments

    @staticmethod
    def interpretName(name):
        """Split a name like ``a（b：c）`` on full-width brackets/colons and ':' into parts."""
        return name.replace('（', ':').replace('）', ':').replace('：', ':').split(':')

    @staticmethod
    def splitVertical(data, indices=None):
        """Split every row of *data* into column slices ``[0,i1), [i1,i2), ..., [in, end)``.

        :param indices: split points (default ``[1]``: ID column vs. the rest); not modified.
        :return: a list with one list of row slices per part.  For empty *data* it
            returns ``[(), ()]`` regardless of *indices*.
        """
        if len(data) == 0:
            return [tuple(), tuple()]
        bounds = [0] + list([1] if indices is None else indices) + [len(data[0])]
        return [[row[l:r] for row in data] for l, r in zip(bounds, bounds[1:])]

    @staticmethod
    def tupleList(data):
        """Return *data* as a list of tuples: lists become tuples, scalars become 1-tuples."""
        if len(data) == 0:
            return data
        if type(data[0]) is list:
            return [tuple(row) for row in data]
        elif type(data[0]) is tuple:
            return data
        else:
            return [(cell,) for cell in data]

    # Operation on database
    def getTables(self):
        """Return the names of all tables in the database."""
        sqlGetTable = "select name from sqlite_master where type='table';"
        return [table[0] for table in self.cursor.execute(sqlGetTable)]

    def createTable(self, table, pk=None):
        """Create *table* with one ``TEXT PRIMARY KEY`` column (*pk*, default ``self.id``) if missing."""
        self.cursor.execute("select name from sqlite_master where type='table' and name=?;", (table,))
        if len(self.cursor.fetchall()) == 0:
            sqlCreateTable = "CREATE TABLE [" + table + "]([" \
                             + (self.id if pk is None else pk) + "] TEXT PRIMARY KEY);"
            self.cursor.execute(sqlCreateTable)

    def copyTable(self, tableDst, tableSrc):
        """(Re)create *tableDst* with the columns of *tableSrc* but no rows."""
        if tableDst in self.getTables():
            self.dropTable(tableDst)
        sqlCopyTable = "create table [" + tableDst + "] as select * from [" + tableSrc + "] where 0=1"
        self.cursor.execute(sqlCopyTable)

    def dropTable(self, table):
        """Drop *table*."""
        self.cursor.execute("drop table [" + table + "]")

    def getViews(self):
        """Return the names of all views in the database."""
        sqlGetView = "select name from sqlite_master where type='view';"
        return [view[0] for view in self.cursor.execute(sqlGetView)]

    def dropView(self, view):
        """Drop *view*."""
        self.cursor.execute("drop view [" + view + "]")

    def getFields(self, table):
        """Return the column names of *table* in order ([] if it does not exist)."""
        self.cursor.execute("pragma table_info([" + table + "]);")
        return [column[1] for column in self.cursor.fetchall()]

    def getTypenames(self, table):
        """Return the declared column types of *table* in order."""
        self.cursor.execute("pragma table_info([" + table + "]);")
        return [column[2] for column in self.cursor.fetchall()]

    def addField(self, table, field, typename='TEXT'):
        """Add column *field* of SQL type *typename* to *table* unless it already exists."""
        if field not in self.getFields(table):
            sqlAddColumn = "alter table [" + table + "] add column [" + field + "] " + typename + ";"
            self.cursor.execute(sqlAddColumn)

    def getRows(self, table):
        """Return the values of the ``self.id`` column of *table*."""
        sqlSelect = "select [" + self.id + "] from [" + table + "]"
        return [line[0] for line in self.cursor.execute(sqlSelect)]

    def addRow(self, table, ID):
        """Insert a row with key *ID* into *table* (keyed by its first column) unless present."""
        pk = self.getId(table)
        sqlSelect = "select [" + pk + "] from [" + table + "] where [" + pk + "]=?"
        self.cursor.execute(sqlSelect, (ID,))
        if len(self.cursor.fetchall()) > 0:
            return
        sqlInsertValues = "insert into [" + table + "] ([" + pk + "]) values(?);"
        self.cursor.execute(sqlInsertValues, (ID,))

    def addRows(self, table, IDs):
        """Insert the IDs (first element of each tuple in *IDs*, as str) missing from *table*.

        New rows are inserted in the order of *IDs*, duplicates once.
        """
        sqlSelect = "select [" + self.id + "] from [" + table + "]"
        existingIDs = {line[0] for line in self.cursor.execute(sqlSelect)}
        pendingIDs = dict.fromkeys(str(ID[0]) for ID in IDs)  # ordered de-duplication
        newIDs = [(ID,) for ID in pendingIDs if ID not in existingIDs]
        sqlInsertValues = "insert into [" + table + "] ([" + self.id + "]) values(?);"
        self.cursor.executemany(sqlInsertValues, newIDs)

    def fillCell(self, table, ID, field, value):
        """Set *field* = *value* in the row of *table* whose key is *ID*."""
        self.fillCells(table, [(ID,)], field, [(value,)])

    def getId(self, table):
        """Return the name of the first (key) column of *table*."""
        return self.getFields(table)[0]

    def fillCells(self, table, IDs, field, values):
        """Set *field* of each row keyed by ``IDs[i]`` to ``values[i]`` (both 1-tuples)."""
        sqlUpdateValues = "update [" + table + "] set [" + field + "]=?" \
                          + " where [" + self.getId(table) + "]=? ;"
        rows = [tuple(values[i] + IDs[i]) for i in range(len(IDs))]
        self.cursor.executemany(sqlUpdateValues, rows)

    def getCells(self, table, fields=None, expr=None, params=()):
        """Return rows of *table*: all columns or just *fields*, optionally filtered.

        :param fields: a column name or a list/tuple of column names (default: all).
        :param expr: raw SQL condition appended after ``where``; may use ``?`` placeholders.
        :param params: values bound to the placeholders in *expr*.
        """
        if fields is not None and not isinstance(fields, (list, tuple)):
            fields = [fields]
        columns = "*" if fields is None else "[" + "],[".join(fields) + "]"
        sqlGetData = "select " + columns + " from [" + table + "]"
        if expr is not None:
            sqlGetData += " where " + expr
        self.cursor.execute(sqlGetData, params)
        return self.cursor.fetchall()

    def copyCells(self, tableDst, tableSrc, field):
        """Copy column *field* from *tableSrc* to *tableDst*, matching rows by ``self.id``.

        Done as one batched UPDATE from Python; a correlated
        ``UPDATE ... SET field=(SELECT ...)`` proved too slow.
        """
        data = self.getCells(tableSrc, [self.id, field])
        IDs, data = self.splitVertical(data)
        self.fillCells(tableDst, IDs, field, data)

    def deleteCells(self, table, expr=None):
        """Delete the rows of *table* matching SQL condition *expr* (all rows if None)."""
        sqlDelete = "delete from [" + table + "]" + ((" where " + expr) if expr is not None else "")
        self.cursor.execute(sqlDelete)

    def deleteNull(self, table, field):
        """Delete the rows of *table* whose *field* is NULL."""
        self.deleteCells(table, "[" + field + "] is NULL")

    # Mathematical method
    @staticmethod
    def normalize(x, order=2, w=None):
        """Return ``sum(w_i * x_i**order) / sum(w_i)`` over the non-None entries of *x*.

        If the weights sum to 0, the unnormalized sum is returned.  *w* defaults to all 1s.
        """
        dim = len(x)
        if w is None:
            w = [1] * dim
        sumX = float(0)
        sumW = float(0)
        for i in range(dim):
            if x[i] is not None:
                sumX += pow(x[i], order) * w[i]
                sumW += w[i]
        return sumX / sumW if sumW != 0 else sumX  # works even when w is negative

    def assortStdNormCkmeans_1d_dp(self, X, nCluster, w):
        """Assign each row of *X* to one of *nCluster* clusters.

        1. Standardize every column with ``(x - L) / D`` (``D = max - min``; a constant
           column maps to 1.0).  Train mode computes L and D from *X* and writes them to
           the checkpoint (if open); predict mode reads them from it.
        2. Reduce each row to one value with ``normalize(row, 1, w)``, i.e. the
           *w*-weighted mean of its non-None entries (order 1, despite "Norm" in the name).
        3. Train mode calls ``Ckmeans_1d_dp`` in the native library (presumably the
           Ckmeans.1d.dp dynamic-programming 1-D k-means) and writes the centers to the
           checkpoint; predict mode reads the centers and picks the nearest one per row
           (ties go to the lower cluster number).

        :param X: numpy array of data rows without the ID column; standardized in place
            unless it has an integer dtype (then a float copy is used).
        :return: ``(labels, centers, evaluation)``: the cluster number of each row, the
            cluster centers, and a score dict (``None`` in train mode when no evaluation
            table is set, ``{}`` in predict mode).
        """
        nRow = len(X)
        assert nRow > 0
        nCol = len(X[0])
        if self.mode == 'predict':
            L = [float(x) for x in self.checkpoint.readline().strip().split(' ')]
            D = [float(x) for x in self.checkpoint.readline().strip().split(' ')]
        else:
            L = X[0].copy()
            R = X[0].copy()
            for row in X:
                for i in range(nCol):
                    if row[i] is not None:
                        L[i] = row[i] if L[i] is None else min(L[i], row[i])
                        R[i] = row[i] if R[i] is None else max(R[i], row[i])
            D = [R[i] - L[i] for i in range(nCol)]
            if self.checkpoint is not None:
                self.checkpoint.write(' '.join(["%.10f" % x for x in L]) + '\n')
                self.checkpoint.write(' '.join(["%.10f" % x for x in D]) + '\n')
        if numpy.issubdtype(X.dtype, numpy.integer):
            X = X.astype('float64')  # an int array would truncate the standardized values
        for row in X:
            for i in range(nCol):
                if row[i] is not None:
                    row[i] = float(row[i] - L[i]) / D[i] if D[i] != 0 else float(1)
        n = nRow
        x = (ctypes.c_double * n)()  # C array of per-row values handed to the native library
        for i in range(n):
            x[i] = self.normalize(X[i], 1, w)
        if self.mode == 'predict':
            centers = [float(center) for center in self.checkpoint.readline().strip().split(' ')]
            labels = [None] * nRow
            for i in range(nRow):
                labels[i] = 0
                for j in range(nCluster):
                    if abs(centers[j] - x[i]) < abs(centers[labels[i]] - x[i]):
                        labels[i] = j
            evaluation = {}
        else:
            y = (ctypes.c_int * n)()  # output: cluster number of each value
            noWeights = ctypes.c_void_p(0)  # NULL weight pointer (presumably: unweighted)
            minK = nCluster
            maxK = nCluster
            centers = (ctypes.c_double * maxK)()
            withinss = (ctypes.c_double * maxK)()
            size = (ctypes.c_double * maxK)()
            BICs = (ctypes.c_double * maxK)()
            self.lib.Ckmeans_1d_dp(x, n, minK, maxK, noWeights, y, centers, size, withinss, BICs)
            if self.checkpoint is not None:
                self.checkpoint.write(' '.join(["%.10f" % centers[i] for i in range(nCluster)]) + '\n')
            labels = [y[i] for i in range(n)]
            if self.eval is None:
                return labels, centers, None
            evaluation = self.getEvaluation(x, y, n, nCluster, centers, size)
        return labels, centers, evaluation

    def getEvaluation(self, x, y, n, nCluster, centers, size):
        """Ask the native library's ``evaluate`` for clustering scores; return ``{name: score}``.

        NOTE(review): the score and name arrays are returned by the library through
        pointers and are never freed here; confirm who owns that memory.
        """
        n = ctypes.c_int(n)
        nCluster = ctypes.c_int32(nCluster)
        nScore = ctypes.c_int32()
        pScore = ctypes.POINTER(ctypes.c_double)()
        pMethod = ctypes.POINTER(ctypes.c_char_p)()
        self.lib.evaluate(x, y, n, nCluster, centers, size,
                          ctypes.byref(nScore), ctypes.byref(pScore), ctypes.byref(pMethod))
        return {pMethod[i].decode("utf-8"): pScore[i] for i in range(nScore.value)}


if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script, not on import.
    main()
